package sqlitegorm

import (
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"
	"unsafe"

	Log "github.com/cihub/seelog"
	"github.com/shopspring/decimal"
)

// // 将查询结果转换成map数组,常用于原生sql查询
// func ScanRows2map(rows *sql.Rows) []map[string]string {
// 	if nil == rows {
// 		return nil
// 	}

// 	res := make([]map[string]string, 0)               //  定义结果 map
// 	colTypes, _ := rows.ColumnTypes()                 // 列信息
// 	var rowParam = make([]interface{}, len(colTypes)) // 传入到 rows.Scan 的参数 数组
// 	var rowValue = make([]interface{}, len(colTypes)) // 接收数据一行列的数组

// 	for i, colType := range colTypes {
// 		rowValue[i] = reflect.New(colType.ScanType())           // 跟据数据库参数类型，创建默认值 和类型
// 		rowParam[i] = reflect.ValueOf(&rowValue[i]).Interface() // 跟据接收的数据的类型反射出值的地址
// 	}

// 	// 遍历每行
// 	for rows.Next() {
// 		//rows.Scan(rowParam) // 赋值到 rowValue 中 go 1.20
// 		rows.Scan(rowParam...) // 赋值到 rowValue 中  go 1.16
// 		record := make(map[string]string)
// 		for i, colType := range colTypes {
// 			if rowValue[i] == nil {
// 				record[colType.Name()] = ""
// 				continue
// 			}

// 			//如果字段类型为int,则需要进一步判断
// 			//并且1.如果获得值类型为int64,则需要按int64处理,常用于类似以下的查询:
// 			//rows, _ := SqlFactory{}.GetDB().Raw("select * from table where uId=@GuId and sName=@GsName", &where).Rows()
// 			//并且2.如果数据库类型虽然为INT,但获取的值被以string进行接收, 则要按字符串的方式进行,常用于类似以下的查询:
// 			//rows, _ := SqlFactory{}.GetDB().Raw("select * from Student").Rows()
// 			if colType.DatabaseTypeName() == "INT" { //
// 				switch rowValue[i].(type) {
// 				case int64: //
// 					record[colType.Name()] = strconv.FormatInt(int64(rowValue[i].(int64)), 10)
// 					continue
// 				}
// 			}

// 			if colType.DatabaseTypeName() == "BIGINT" {
// 				record[colType.Name()] = toStr(rowValue[i])
// 				continue
// 			}

// 			if colType.DatabaseTypeName() == "FLOAT" {
// 				record[colType.Name()] = toStr(rowValue[i])
// 				continue
// 			}

// 			if colType.DatabaseTypeName() == "DOUBLE" {
// 				record[colType.Name()] = toStr(rowValue[i])
// 				continue
// 			}

// 			if colType.DatabaseTypeName() == "DECIMAL" {
// 				record[colType.Name()] = toStr(rowValue[i])
// 				continue
// 			}

// 			if colType.DatabaseTypeName() == "DATETIME" {
// 				record[colType.Name()] = rowValue[i].(time.Time).Format("2006-01-02 15:04:05")
// 				continue
// 			}

// 			record[colType.Name()] = byte2Str(rowValue[i].([]byte))
// 		}
// 		res = append(res, record)
// 	}

// 	return res
// }

// 将查询结果转换成map数组,常用于原生sql查询
func ScanRows2mapI(rows *sql.Rows) []map[string]interface{} {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil
	}

	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range columns {
		valuePtrs[i] = &values[i]
	}

	colNames := make(map[string]string)
	temp, _ := rows.ColumnTypes() // 列信息
	for _, colType := range temp {
		colNames[colType.Name()] = strings.ToUpper(colType.DatabaseTypeName()) // 跟据数据库参数类型，创建默认值 和类型
	}

	maps := []map[string]interface{}{} //就算没数据也不会返回nil
	for rows.Next() {
		err := rows.Scan(valuePtrs...)
		if err != nil {
			return nil
		}

		m := make(map[string]interface{})
		for i, col := range columns {
			val := values[i]
			m[col] = val //先设置为从数据库获取的值

			if val == nil {
				continue
			}

			//数据调整,因从数据库获取的值并不一定符合数据类型,例如Decimal字段返回的数据类型不是decimal.Decimal而是string
			switch colNames[col] {
			case "INT":
				value := toInt(val, -99999)
				if value != -99999 {
					m[col] = value
				}
			case "NUMERIC":
				value := toInt(val, -99999)
				if value != -99999 {
					m[col] = value
				}
			case "TINYINT":
				value := toInt(val, -99999)
				if value != -99999 {
					m[col] = value
				}
			case "BIGINT":
				value := toInt(val, -99999)
				if value != -99999 {
					m[col] = value
				}
			case "FLOAT":
				value := toFloat(val, 64, -99999999.99999999)
				if value != -99999999.99999999 {
					m[col] = value
				}
			case "DOUBLE":
				value := toFloat(val, 64, -99999999.99999999)
				if value != -99999999.99999999 {
					m[col] = value
				}
			case "DECIMAL":
				if reflect.TypeOf(val).Elem().Name() == "Decimal" {
					m[col] = val.(*decimal.Decimal)
					continue
				}

				if reflect.TypeOf(val).Elem().Name() == "uint8" {
					m[col] = toStr(val)
					continue
				}

				m[col] = val
			case "DATETIME":
				m[col] = val.(time.Time).Format("2006-01-02 15:04:05")
			case "DATE":
				m[col] = val.(time.Time).Format("2006-01-02 15:04:05")
			case "TIMESTAMP":
				m[col] = toInt(val, -99999)
			default:
				b, ok := val.([]byte)
				if ok {
					m[col] = string(b)
					continue
				}

				m[col] = val
			}
		}

		maps = append(maps, m)
	}

	return maps
}

// func ScanRows2mapI(rows *sql.Rows) []map[string]interface{} {
// 	if nil == rows {
// 		return nil
// 	}

// 	res := make([]map[string]interface{}, 0)          //  定义结果 map
// 	colTypes, _ := rows.ColumnTypes()                 // 列信息
// 	var rowParam = make([]interface{}, len(colTypes)) // 传入到 rows.Scan 的参数 数组
// 	var rowValue = make([]interface{}, len(colTypes)) // 接收数据一行列的数组

// 	for i, colType := range colTypes {
// 		rowValue[i] = reflect.New(colType.ScanType())           // 跟据数据库参数类型，创建默认值 和类型
// 		rowParam[i] = reflect.ValueOf(&rowValue[i]).Interface() // 跟据接收的数据的类型反射出值的地址
// 	}

// 	// 遍历
// 	for rows.Next() {
// 		rows.Scan(rowParam) // 赋值到 rowValue 中 go 1.20
// 		//rows.Scan(rowParam...) // 赋值到 rowValue 中  go 1.16
// 		record := make(map[string]interface{})
// 		for i, colType := range colTypes {
// 			if rowValue[i] == nil {
// 				record[colType.Name()] = ""
// 				continue
// 			}

// 			//如果字段类型为int,则需要进一步判断
// 			//并且1.如果获得值类型为int64,则需要按int64处理,常用于类似以下的查询:
// 			//rows, _ := SqlFactory{}.GetDB().Raw("select * from table where uId=@GuId and sName=@GsName", &where).Rows()
// 			//并且2.如果数据库类型虽然为INT,但获取的值被以string进行接收, 则要按字符串的方式进行,常用于类似以下的查询:
// 			//rows, _ := SqlFactory{}.GetDB().Raw("select * from Student").Rows()
// 			if colType.DatabaseTypeName() == "INT" { //
// 				switch value := rowValue[i].(type) {
// 				case int64:
// 					record[colType.Name()] = rowValue[i]
// 					//record[colType.Name()] = strconv.FormatInt(int64(rowValue[i].(int64)), 10)
// 					continue
// 				case string:
// 					record[colType.Name()] = toInt64(rowValue[i], -99999)
// 					continue
// 				case []uint8:
// 					record[colType.Name()] = toInt64(rowValue[i], -99999)
// 					continue
// 				default:
// 					fmt.Println(value)
// 				}
// 			}

// 			if colType.DatabaseTypeName() == "BIGINT" {
// 				switch value := rowValue[i].(type) {
// 				case int64:
// 					record[colType.Name()] = rowValue[i]
// 					//record[colType.Name()] = strconv.FormatInt(int64(rowValue[i].(int64)), 10)
// 					continue
// 				case string:
// 					record[colType.Name()] = toInt64(rowValue[i], -99999)
// 					continue
// 				case []uint8:
// 					record[colType.Name()] = toInt64(rowValue[i], -99999)
// 					continue
// 				default:
// 					fmt.Println(value)
// 				}
// 			}

// 			if colType.DatabaseTypeName() == "DATETIME" {
// 				record[colType.Name()] = rowValue[i]
// 				continue
// 			}

// 			record[colType.Name()] = byte2Str(rowValue[i].([]byte))
// 		}
// 		res = append(res, record)
// 	}

// 	return res
// }

// Byte转Str
func byte2Str(b []byte) string {
	return *(*string)(unsafe.Pointer(&b))
}

// 转换字符串
func toStr(data interface{}) string {
	switch obj := data.(type) {
	case []uint8:
		return byte2Str(obj)
	default:
		return fmt.Sprintf("%v", data)
	}
}

// // 对象(字符串)转64整型
// func toInt64(data interface{}, iDefault int64) int64 {
// 	var str string
// 	switch obj := data.(type) {
// 	case []uint8:
// 		str = byte2Str(obj)
// 	default:
// 		str = fmt.Sprintf("%v", obj)
// 	}

// 	if str == "" { //字符串不能判断nil
// 		return iDefault
// 	}

// 	result, err := strconv.ParseInt(str, 10, 64)
// 	if err != nil {
// 		return iDefault
// 	}

// 	return result
// }

// 对象(字符串)转整型
func toInt(data interface{}, iDefault int) int {
	var str string
	switch obj := data.(type) {
	case []uint8:
		str = byte2Str(obj)
	default:
		str = fmt.Sprintf("%v", obj)
	}

	if str == "" { //字符串不能判断nil
		return iDefault
	}

	result, err := strconv.ParseInt(str, 10, 64)
	if err != nil {
		return iDefault
	}

	return int(result)
}

// 对象(字符串)转64整型
func toFloat(data interface{}, bitSize int, iDefault float64) float64 {
	var str string
	switch obj := data.(type) {
	case []uint8:
		str = byte2Str(obj)
	default:
		str = fmt.Sprintf("%v", obj)
	}

	if str == "" { //字符串不能判断nil
		return iDefault
	}

	result, err := strconv.ParseFloat(str, bitSize)
	if err != nil {
		return iDefault
	}

	return result
}

// 取数据库名称
func GetDbName(name string) string {
	return GetVariable(name)
}

// 取数据库全局变量
func GetVariable(name string) string {
	if name == "" {
		return ""
	}

	for key := range dbVariables {
		if name == key {
			return dbVariables[key]
		}
	}

	return ""
}

// 替换字符串中的所有全局变量
func ReplaceVariable(sqlstr string) string {
	if sqlstr == "" {
		return ""
	}

	result := sqlstr
	for key, val := range dbVariables {
		if !strings.Contains(result, "${") {
			return result
		}

		if !strings.Contains(result, "${"+key+"}") {
			continue
		}

		result = strings.Replace(result, "${"+key+"}", val, -1)
	}

	return result
}

// 添加记录,返回影响行数及错误信息
func Add(entity interface{}) (int64, error) {
	result := GetDB().Create(entity)

	return result.RowsAffected, result.Error
}

// 调用数据查询数量
func Count(sql string, params ...interface{}) (int, error) {
	var iCount int
	dbResult := doDb(sql, params, globGormDB.Raw).Scan(&iCount)
	if dbResult.Error != nil {
		return 0, dbResult.Error
	}

	return iCount, nil
}

// 调用数据查询
func Query(sql string, dest interface{}, where ...interface{}) (interface{}, error) {
	var dbResult GormDB
	if len(where) < 1 {
		dbResult = GetDB().Raw(sql).Scan(dest)
	} else {
		dbResult = GetDB().Raw(sql, where...).Scan(dest)
	}

	if dbResult.Error != nil {
		Log.Error("查询发生异常:", dbResult.Error)
		return nil, dbResult.Error
	}

	return &dest, dbResult.Error
}

// 调用数据查询
func Find(sql string, params ...interface{}) (tx GormDB) {
	return doDb(sql, params, globGormDB.Raw)
}

// 调用数据查询
// 返回规则: 发生错误时: {数据为nil,错误码值,错误信息}
// 正确时: {数据,数据数量,nil}
func FindToMap(text string, params ...interface{}) ([]map[string]interface{}, int, error) {
	rows, err := doDb(text, params, globGormDB.Raw).Rows()
	if err != nil {
		Log.Error("查询发生异常:", err)
		return nil, 1002, err
	}
	defer rows.Close()

	res := ScanRows2mapI(rows)
	if res == nil {
		Log.Error("查询成功后进行数据转换时发生异常,:无法正确转换")
		return nil, 1003, errors.New("查询发生异常")
	}

	rowCount := len(res)
	if rowCount < 1 {
		return res, rowCount, nil //没有数据
	}

	return res, rowCount, nil
}

// 调用数据查询一条记录
// 返回规则: 发生错误时: {数据为nil,错误码值,错误信息}
// 正确时: {数据,数据数量,nil}
func FindOneMap(text string, params ...interface{}) (map[string]interface{}, int, error) {
	rows, err := doDb(text, params, globGormDB.Raw).Rows()
	if err != nil {
		Log.Error("查询发生异常:", err)
		return nil, 1001, err
	}
	defer rows.Close()

	res := ScanRows2mapI(rows)
	if res == nil {
		Log.Error("查询成功后进行数据转换时发生异常,:无法正确转换")
		return nil, 1002, errors.New("查询发生异常")
	}

	rowCount := len(res)
	if rowCount < 1 {
		return nil, 1003, errors.New("没有数据")
	}

	return res[0], rowCount, nil
}

// 格式化结果集(备用代码)
// func FormatScan(results *[]map[string]interface{}, colTypes map[string]*sql.ColumnType) {
// 	for _, row := range *results {
// 		for key, value := range row {
// 			if value == nil {
// 				continue
// 			}

// 			switch colTypes[key].DatabaseTypeName() {
// 			case "INT":
// 				if reflect.TypeOf(value).String() == "int32" {
// 					row[key] = int(value.(int32))
// 					continue
// 				}

// 				if reflect.TypeOf(value).String() == "int64" {
// 					row[key] = int(value.(int64))
// 					continue
// 				}

// 				tmp := toInt(value, -99999)
// 				if tmp != -99999 {
// 					row[key] = tmp
// 				}

// 				continue
// 			case "NUMERIC":
// 				tmp := toInt64(value, -99999)
// 				if tmp != -99999 {
// 					row[key] = tmp
// 				}

// 				continue
// 			case "TINYINT":
// 				tmp := toInt64(value, -99999)
// 				if tmp != -99999 {
// 					row[key] = tmp
// 				}

// 				continue
// 			case "BIGINT":
// 				tmp := toInt64(value, -99999)
// 				if tmp != -99999 {
// 					row[key] = tmp
// 				}

// 				continue
// 			case "FLOAT":
// 				tmp := toFloat(value, 64, -99999999.99999999)
// 				if tmp != -99999999.99999999 {
// 					row[key] = tmp
// 				}

// 				continue
// 			case "DOUBLE":
// 				tmp := toFloat(value, 64, -99999999.99999999)
// 				if tmp != -99999999.99999999 {
// 					row[key] = tmp
// 				}

// 				continue
// 			case "DECIMAL":
// 				if reflect.TypeOf(value).Elem().Name() == "Decimal" {
// 					row[key] = value.(*decimal.Decimal)
// 					continue
// 				}

// 				if reflect.TypeOf(value).Elem().Name() == "uint8" {
// 					row[key] = toStr(value)
// 					continue
// 				}

// 				continue
// 			case "DATETIME":
// 				row[key] = value.(time.Time).Format("2006-01-02 15:04:05")
// 				continue
// 			case "DATE":
// 				row[key] = value.(time.Time).Format("2006-01-02 15:04:05")
// 				continue
// 			case "TIME":
// 				row[key] = value.(time.Time).Format("15:04:05")
// 				continue
// 			case "TIMESTAMP":
// 				row[key] = toInt64(value, -99999)
// 				continue
// 			}
// 		}
// 	}
// }

// 调用数据查询
func Raw(sql string, params ...interface{}) (tx GormDB) {
	return doDb(sql, params, globGormDB.Raw)
}

// 调用数据查询
func RawRows(sql string, params ...interface{}) (*sql.Rows, error) {
	return doDb(sql, params, globGormDB.Raw).Rows()
}

// 调用数据更新
func Exec(sql string, params ...interface{}) (tx GormDB) {
	return doDb(sql, params, globGormDB.Exec)
}

// 调用数据库操作
func doDb(sql string, param []interface{}, dbFunc func(sql string, values ...interface{}) (tx GormDB)) (tx GormDB) {
	if (nil == param) || (len(param) < 1) {
		return dbFunc(sql)
	}

	iCount := len(param)
	if iCount > 1 {
		return dbFunc(sql, param...)
	}

	rtk := reflect.TypeOf(param[0]).Kind()
	if rtk == reflect.Map {
		s := reflect.ValueOf(param[0])
		if s.Len() < 1 {
			return dbFunc(sql)
		}

		return dbFunc(sql, param[0])
	}

	if (rtk != reflect.Slice) && (rtk != reflect.Array) {
		return dbFunc(sql, param[0])
	}

	params := []interface{}{}
	s := reflect.ValueOf(param[0])
	for i := 0; i < s.Len(); i++ {
		params = append(params, s.Index(i).Interface())
	}

	return dbFunc(sql, params...)
}
